//===- OpPythonBindingGen.cpp - Generator of Python API for MLIR Ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpPythonBindingGen uses ODS specification of MLIR ops to generate Python
// binding classes wrapping a generic operation API.
//
//===----------------------------------------------------------------------===//

#include "OpGenHelpers.h"

#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
#include <regex>

using namespace mlir;
using namespace mlir::tblgen;
using llvm::formatv;
using llvm::Record;
using llvm::RecordKeeper;

/// File header and imports placed at the top of a generated Python module.
/// It aliases the `_ods_common` helpers (`_ods_cext`, `_ods_ir`, the accessor
/// helpers) and the `typing` names that the other templates in this file refer
/// to.
/// NOTE(review): the text below contains no `{N}` placeholder, so the former
/// "{0} is the dialect namespace" remark looks stale; confirm at the use site
/// whether this string is passed through `formatv` at all.
constexpr const char *fileHeader = R"Py(
# Autogenerated by mlir-tblgen; don't manually edit.

from ._ods_common import _cext as _ods_cext
from ._ods_common import (
    equally_sized_accessor as _ods_equally_sized_accessor,
    get_default_loc_context as _ods_get_default_loc_context,
    get_op_results_or_values as _get_op_results_or_values,
    segmented_accessor as _ods_segmented_accessor,
)
_ods_ir = _ods_cext.ir
_ods_cext.globals.register_traceback_file_exclusion(__file__)

import builtins
from typing import Sequence as _Sequence, Union as _Union, Optional as _Optional

)Py";

/// Template for the `_Dialect` class registered with
/// `_ods_cext.register_dialect`:
///   {0} is the dialect namespace.
constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
  DIALECT_NAMESPACE = "{0}"
)Py";

/// Template for the statement that imports `_Dialect` from the
/// `._{0}_ops_gen` module:
///   {0} is spliced into the imported module name (presumably the name of the
///       dialect being extended; confirm at the use site).
constexpr const char *dialectExtensionTemplate = R"Py(
from ._{0}_ops_gen import _Dialect
)Py";

/// Template for operation class:
///   {0} is the Python class name;
///   {1} is the operation name;
///   {2} is the docstring for this operation (placed directly after the class
///       header line).
constexpr const char *opClassTemplate = R"Py(
@_ods_cext.register_operation(_Dialect)
class {0}(_ods_ir.OpView):{2}
  OPERATION_NAME = "{1}"
)Py";

/// Template for class level declarations of operand and result
/// segment specs.
///   {0} is either "OPERAND" or "RESULT"
///   {1} is the segment spec
/// Each segment spec is either None (default) or an array of integers
/// where:
///   1 = single element (expect non sequence operand/result)
///   0 = optional element (expect a value or None)
///   -1 = operand/result is a sequence corresponding to a variadic
constexpr const char *opClassSizedSegmentsTemplate = R"Py(
  _ODS_{0}_SEGMENTS = {1}
)Py";

/// Template for class level declarations of the _ODS_REGIONS spec:
///   {0} is the minimum number of regions
///   {1} is the Python bool literal for hasNoVariadicRegions
constexpr const char *opClassRegionSpecTemplate = R"Py(
  _ODS_REGIONS = ({0}, {1})
)Py";

/// Template for single-element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position in the element list.
///   {3} is the type hint.
constexpr const char *opSingleTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {3}:
    return self.operation.{1}s[{2}]
)Py";

/// Template for single-element accessor after a variable-length group:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
///   {4} is the type hint.
/// This works for both a single variadic group (non-negative length) and a
/// single optional element (zero length if the element is absent): the length
/// of the variable-length group is recovered from the total element count and
/// used to shift the index.
constexpr const char *opSingleAfterVariableTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {4}:
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3} + _ods_variadic_group_length - 1]
)Py";

/// Template for an optional element accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
///   {4} is the type hint (wrapped in `_Optional[...]` by the template).
/// This works if we have only one variable-length group (and it's the optional
/// operand/result): we can deduce it's absent if the `len(operation.{1}s)` is
/// smaller than the total number of groups.
constexpr const char *opOneOptionalTemplate = R"Py(
  @builtins.property
  def {0}(self) -> _Optional[{4}]:
    return None if len(self.operation.{1}s) < {2} else self.operation.{1}s[{3}]
)Py";

/// Template for the variadic group accessor in the single variadic group case:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of element groups;
///   {3} is the position of the current group in the group list.
///   {4} is the type hint.
constexpr const char *opOneVariadicTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {4}:
    _ods_variadic_group_length = len(self.operation.{1}s) - {2} + 1
    return self.operation.{1}s[{3}:{3} + _ods_variadic_group_length]
)Py";

/// First part of the template for equally-sized variadic group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the total number of non-variadic groups;
///   {3} is the total number of variadic groups;
///   {4} is the number of non-variadic groups preceding the current group;
///   {5} is the number of variadic groups preceding the current group.
///   {6} is the type hint.
/// The text deliberately ends without a trailing newline: one of the two
/// "second part" templates below, each of which starts with a newline, is
/// emitted right after it to supply the `return` statement.
constexpr const char *opVariadicEqualPrefixTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {6}:
    start, elements_per_group = _ods_equally_sized_accessor(self.operation.{1}s, {2}, {3}, {4}, {5}))Py";

/// Second part of the template for equally-sized case, accessing a single
/// element:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualSimpleTemplate = R"Py(
    return self.operation.{0}s[start]
)Py";

/// Second part of the template for equally-sized case, accessing a variadic
/// group:
///   {0} is either 'operand' or 'result'.
constexpr const char *opVariadicEqualVariadicTemplate = R"Py(
    return self.operation.{0}s[start:start + elements_per_group]
)Py";

/// Template for an attribute-sized group accessor:
///   {0} is the name of the accessor;
///   {1} is either 'operand' or 'result';
///   {2} is the position of the group in the group list;
///   {3} is a return suffix (expected [0] for single-element, empty for
///       variadic, and opVariadicSegmentOptionalTrailingTemplate for optional).
///   {4} is the type hint.
/// The group sizes are read from the op's "{1}SegmentSizes" attribute.
constexpr const char *opVariadicSegmentTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {4}:
    {1}_range = _ods_segmented_accessor(
         self.operation.{1}s,
         self.operation.attributes["{1}SegmentSizes"], {2})
    return {1}_range{3}
)Py";

/// Template for a suffix when accessing an optional element in the
/// attribute-sized case; it refers to the `{0}_range` local defined by
/// opVariadicSegmentTemplate:
///   {0} is either 'operand' or 'result';
constexpr const char *opVariadicSegmentOptionalTrailingTemplate =
    R"Py([0] if len({0}_range) > 0 else None)Py";

/// Template for an operation attribute getter:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
///   {2} is the type hint.
constexpr const char *attributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {2}:
    return self.operation.attributes["{1}"]
)Py";

/// Template for an optional operation attribute getter, which returns None
/// when the attribute is not present on the operation:
///   {0} is the name of the attribute sanitized for Python;
///   {1} is the original name of the attribute.
///   {2} is the type hint (wrapped in `_Optional[...]` by the template).
constexpr const char *optionalAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self) -> _Optional[{2}]:
    if "{1}" not in self.operation.attributes:
      return None
    return self.operation.attributes["{1}"]
)Py";

/// Template for a getter of a unit operation attribute, returns True if the
/// unit attribute is present, False otherwise (unit attributes have meaning
/// by mere presence):
///    {0} is the name of the attribute sanitized for Python,
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeGetterTemplate = R"Py(
  @builtins.property
  def {0}(self) -> bool:
    return "{1}" in self.operation.attributes
)Py";

/// Template for an operation attribute setter; the generated setter raises
/// ValueError when given None:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
///    {2} is the type hint.
constexpr const char *attributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value: {2}):
    if value is None:
      raise ValueError("'None' not allowed as value for mandatory attributes")
    self.operation.attributes["{1}"] = value
)Py";

/// Template for a setter of an optional operation attribute, setting to None
/// removes the attribute:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
///    {2} is the type hint.
constexpr const char *optionalAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value: _Optional[{2}]):
    if value is not None:
      self.operation.attributes["{1}"] = value
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a setter of a unit operation attribute. The value is tested
/// with `bool(value)`, so any falsy value (e.g. None or False) removes the
/// attribute and any truthy value sets it:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *unitAttributeSetterTemplate = R"Py(
  @{0}.setter
  def {0}(self, value):
    if bool(value):
      self.operation.attributes["{1}"] = _ods_ir.UnitAttr.get()
    elif "{1}" in self.operation.attributes:
      del self.operation.attributes["{1}"]
)Py";

/// Template for a deleter of an optional or a unit operation attribute, removes
/// the attribute from the operation:
///    {0} is the name of the attribute sanitized for Python;
///    {1} is the original name of the attribute.
constexpr const char *attributeDeleterTemplate = R"Py(
  @{0}.deleter
  def {0}(self):
    del self.operation.attributes["{1}"]
)Py";

/// Template for a region accessor property:
///   {0} is the name of the accessor;
///   {1} is the expression used to index `self.regions`;
///   {2} is the type hint.
constexpr const char *regionAccessorTemplate = R"Py(
  @builtins.property
  def {0}(self) -> {2}:
    return self.regions[{1}]
)Py";

/// Template for a module-level builder function that forwards to a callable:
///   {0} is the name of the generated function;
///   {1} is the callable being invoked (presumably the op class; confirm at
///       the use site);
///   {2} is the parameter list of the generated function;
///   {3} is the argument list forwarded to {1};
///   {4} is the return type hint;
///   {5} is a suffix appended directly to the call expression.
constexpr const char *valueBuilderTemplate = R"Py(
def {0}({2}) -> {4}:
  return {1}({3}){5}
)Py";

/// Template for a module-level builder function whose result count is only
/// known at runtime. The generated function returns the results list when
/// there is more than one result, the sole result when there is exactly one,
/// and the op itself when there are none:
///   {0} is the name of the generated function;
///   {1} is the op class (used both as the callee and in the return type hint);
///   {2} is the parameter list of the generated function;
///   {3} is the argument list forwarded to {1}.
constexpr const char *valueBuilderVariadicTemplate = R"Py(
def {0}({2}) -> _Union[_ods_ir.OpResult, _ods_ir.OpResultList, {1}]:
  op = {1}({3}); results = op.results
  return results if len(results) > 1 else (results[0] if len(results) == 1 else op)
)Py";

/// Command-line option category grouping the options declared below.
static llvm::cl::OptionCategory
    clOpPythonBindingCat("Options for -gen-python-op-bindings");

/// `-bind-dialect`: the dialect to run the generator for. Defaults to the
/// empty string.
static llvm::cl::opt<std::string>
    clDialectName("bind-dialect",
                  llvm::cl::desc("The dialect to run the generator for"),
                  llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

/// `-dialect-extension`: the prefix of the dialect extension. Defaults to the
/// empty string.
static llvm::cl::opt<std::string> clDialectExtensionName(
    "dialect-extension", llvm::cl::desc("The prefix of the dialect extension"),
    llvm::cl::init(""), llvm::cl::cat(clOpPythonBindingCat));

/// String-to-string map alias.
/// NOTE(review): not referenced in the visible part of this file; what the
/// keys and values stand for must be confirmed at its use sites.
using AttributeClasses = DenseMap<StringRef, StringRef>;

/// Returns true if `str` would clash with a name that is generated by this
/// backend or that is part of the OpView API: anything starting with "_ods_"
/// or ending with "_ods", plus a fixed list of OpView members.
static bool isODSReserved(StringRef str) {
  static llvm::StringSet<> opViewNames(
      {"attributes", "create", "context", "ip", "operands", "print", "get_asm",
       "loc", "verify", "regions", "results", "self", "operation",
       "DIALECT_NAMESPACE", "OPERATION_NAME"});

  // The "_ods" prefix/suffix is reserved for generated helpers.
  if (str.starts_with("_ods_") || str.ends_with("_ods"))
    return true;
  return opViewNames.contains(str);
}

/// Converts `name` into an identifier suitable for the generated Python code
/// and returns the result (the input is returned unchanged if it already is
/// suitable):
///   - every non-alphanumeric character is replaced with '_';
///   - a name starting with a digit is prefixed with '_';
///   - otherwise, a name that is reserved by Python or by ODS gets a trailing
///     '_'.
/// An empty `name` skips the leading-digit check and only goes through the
/// reserved-name checks.
static std::string sanitizeName(StringRef name) {
  std::string processedStr = name.str();
  std::replace_if(
      processedStr.begin(), processedStr.end(),
      [](char c) { return !llvm::isAlnum(c); }, '_');

  // Only inspect the first character when there is one: dereferencing
  // `begin()` of an empty string is undefined behavior.
  if (!processedStr.empty() && llvm::isDigit(processedStr.front()))
    return "_" + processedStr;

  if (isPythonReserved(processedStr) || isODSReserved(processedStr))
    return processedStr + "_";
  return processedStr;
}

static std::string attrSizedTraitForKind(const char *kind) {
  return formatv("::mlir::OpTrait::AttrSized{0}{1}Segments",
                 StringRef(kind).take_front().upper(),
                 StringRef(kind).drop_front());
}

/// Emits accessors to "elements" of an Op definition. Currently, the supported
/// elements are operands and results, indicated by `kind`, which must be either
/// `operand` or `result` and is used verbatim in the emitted code.
static void emitElementAccessors(
    const Operator &op, raw_ostream &os, const char *kind,
    unsigned numVariadicGroups, unsigned numElements,
    llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
        getElement) {
  assert(llvm::is_contained(SmallVector<StringRef, 2>{"operand", "result"},
                            kind) &&
         "unsupported kind");

  // Traits indicating how to process variadic elements.
  std::string sameSizeTrait = formatv("::mlir::OpTrait::SameVariadic{0}{1}Size",
                                      StringRef(kind).take_front().upper(),
                                      StringRef(kind).drop_front());
  std::string attrSizedTrait = attrSizedTraitForKind(kind);

  // If there is only one variable-length element group, its size can be
  // inferred from the total number of elements. If there are none, the
  // generation is straightforward.
  if (numVariadicGroups <= 1) {
    bool seenVariableLength = false;
    for (unsigned i = 0; i < numElements; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (element.isVariableLength())
        seenVariableLength = true;
      if (element.name.empty())
        continue;
      const char *type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                           : "_ods_ir.OpResult";
      if (element.isVariableLength()) {
        if (element.isOptional()) {
          os << formatv(opOneOptionalTemplate, sanitizeName(element.name), kind,
                        numElements, i, type);
        } else {
          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
                                                   : "_ods_ir.OpResultList";
          os << formatv(opOneVariadicTemplate, sanitizeName(element.name), kind,
                        numElements, i, type);
        }
      } else if (seenVariableLength) {
        os << formatv(opSingleAfterVariableTemplate, sanitizeName(element.name),
                      kind, numElements, i, type);
      } else {
        os << formatv(opSingleTemplate, sanitizeName(element.name), kind, i,
                      type);
      }
    }
    return;
  }

  // Handle the operations where variadic groups have the same size.
  if (op.getTrait(sameSizeTrait)) {
    // Count the number of simple elements
    unsigned numSimpleLength = 0;
    for (unsigned i = 0; i < numElements; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (!element.isVariableLength()) {
        ++numSimpleLength;
      }
    }

    // Generate the accessors
    int numPrecedingSimple = 0;
    int numPrecedingVariadic = 0;
    for (unsigned i = 0; i < numElements; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (!element.name.empty()) {
        std::string type;
        if (element.isVariableLength()) {
          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.OpOperandList"
                                                   : "_ods_ir.OpResultList";
        } else {
          type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                   : "_ods_ir.OpResult";
        }
        os << formatv(opVariadicEqualPrefixTemplate, sanitizeName(element.name),
                      kind, numSimpleLength, numVariadicGroups,
                      numPrecedingSimple, numPrecedingVariadic, type);
        os << formatv(element.isVariableLength()
                          ? opVariadicEqualVariadicTemplate
                          : opVariadicEqualSimpleTemplate,
                      kind);
      }
      if (element.isVariableLength())
        ++numPrecedingVariadic;
      else
        ++numPrecedingSimple;
    }
    return;
  }

  // Handle the operations where the size of groups (variadic or not) is
  // provided as an attribute. For non-variadic elements, make sure to return
  // an element rather than a singleton container.
  if (op.getTrait(attrSizedTrait)) {
    for (unsigned i = 0; i < numElements; ++i) {
      const NamedTypeConstraint &element = getElement(op, i);
      if (element.name.empty())
        continue;
      std::string trailing;
      std::string type = std::strcmp(kind, "operand") == 0
                             ? "_ods_ir.OpOperandList"
                             : "_ods_ir.OpResultList";
      if (!element.isVariableLength() || element.isOptional()) {
        type = std::strcmp(kind, "operand") == 0 ? "_ods_ir.Value"
                                                 : "_ods_ir.OpResult";
        if (!element.isVariableLength()) {
          trailing = "[0]";
        } else if (element.isOptional()) {
          type = "_Optional[" + type + "]";
          trailing = std::string(
              formatv(opVariadicSegmentOptionalTrailingTemplate, kind));
        }
      }

      os << formatv(opVariadicSegmentTemplate, sanitizeName(element.name), kind,
                    i, trailing, type);
    }
    return;
  }

  llvm::PrintFatalError("unsupported " + llvm::Twine(kind) + " structure");
}

/// Free function helpers accessing Operator components.

/// Returns the number of operands of `op` as an `int`.
static int getNumOperands(const Operator &op) {
  const int count = op.getNumOperands();
  return count;
}
/// Returns the `i`-th operand of `op`.
static const NamedTypeConstraint &getOperand(const Operator &op, int i) {
  const NamedTypeConstraint &operand = op.getOperand(i);
  return operand;
}
/// Returns the number of results of `op` as an `int`.
static int getNumResults(const Operator &op) {
  const int count = op.getNumResults();
  return count;
}
/// Returns the `i`-th result of `op`.
static const NamedTypeConstraint &getResult(const Operator &op, int i) {
  const NamedTypeConstraint &result = op.getResult(i);
  return result;
}

/// Emits accessors to Op operands.
static void emitOperandAccessors(const Operator &op, raw_ostream &os) {
  emitElementAccessors(op, os, "operand", op.getNumVariableLengthOperands(),
                       getNumOperands(op), getOperand);
}

/// Emits accessors Op results.
static void emitResultAccessors(const Operator &op, raw_ostream &os) {
  emitElementAccessors(op, os, "result", op.getNumVariableLengthResults(),
                       getNumResults(op), getResult);
}

/// Maps the C++ storage type of `attr` to the name of a Python attribute
/// class (without the `_ods_ir.` prefix). Storage types that are not
/// recognized map to the generic "Attribute".
static std::string getPythonAttrName(mlir::tblgen::Attribute attr) {
  const auto storageType = attr.getStorageType();

  // Storage types whose Python name depends on more than the storage type, or
  // differs from the unqualified C++ class name.
  if (storageType == "::mlir::DenseElementsAttr") {
    // Float element attributes get the more specific FP class.
    for (const Record *sc : attr.getDef().getSuperClasses()) {
      const std::string superClassName = sc->getNameInitAsString();
      if (superClassName == "FloatElementsAttr" ||
          superClassName == "RankedFloatElementsAttr")
        return "DenseFPElementsAttr";
    }
    return "DenseElementsAttr";
  }
  if (storageType == "::mlir::IntegerAttr") {
    // 1-bit integer attributes are mapped to BoolAttr.
    if (attr.getAttrDefName() == "I1Attr")
      return "BoolAttr";
    return "IntegerAttr";
  }
  if (storageType == "::mlir::DictionaryAttr")
    return "DictAttr";

  // Remaining supported storage types: the Python class name is the C++ class
  // name with the "::mlir::" qualification removed.
  static const llvm::StringSet<> sameNamedClasses(
      {"AffineMapAttr",        "ArrayAttr",
       "BoolAttr",             "DenseBoolArrayAttr",
       "DenseF32ArrayAttr",    "DenseF64ArrayAttr",
       "DenseFPElementsAttr",  "DenseI16ArrayAttr",
       "DenseI32ArrayAttr",    "DenseI64ArrayAttr",
       "DenseI8ArrayAttr",     "DenseIntElementsAttr",
       "DenseResourceElementsAttr", "FlatSymbolRefAttr",
       "FloatAttr",            "IntegerSetAttr",
       "OpaqueAttr",           "StridedLayoutAttr",
       "StringAttr",           "SymbolRefAttr",
       "TypeAttr",             "UnitAttr"});
  StringRef unqualified = storageType;
  if (unqualified.consume_front("::mlir::") &&
      sameNamedClasses.contains(unqualified))
    return unqualified.str();

  return "Attribute";
}

/// Emits accessors to Op attributes.
static void emitAttributeAccessors(const Operator &op, raw_ostream &os) {
  for (const auto &namedAttr : op.getAttributes()) {
    // Skip "derived" attributes because they are just C++ functions that we
    // don't currently expose.
    if (namedAttr.attr.isDerivedAttr())
      continue;

    if (namedAttr.name.empty())
      continue;

    std::string sanitizedName = sanitizeName(namedAttr.name);

    // Unit attributes are handled specially.
    if (namedAttr.attr.getStorageType().trim() == "::mlir::UnitAttr") {
      os << formatv(unitAttributeGetterTemplate, sanitizedName, namedAttr.name);
      os << formatv(unitAttributeSetterTemplate, sanitizedName, namedAttr.name);
      os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
      continue;
    }

    std::string type = "_ods_ir." + getPythonAttrName(namedAttr.attr);
    if (namedAttr.attr.isOptional()) {
      os << formatv(optionalAttributeGetterTemplate, sanitizedName,
                    namedAttr.name, type);
      os << formatv(optionalAttributeSetterTemplate, sanitizedName,
                    namedAttr.name, type);
      os << formatv(attributeDeleterTemplate, sanitizedName, namedAttr.name);
    } else {
      os << formatv(attributeGetterTemplate, sanitizedName, namedAttr.name,
                    type);
      os << formatv(attributeSetterTemplate, sanitizedName, namedAttr.name,
                    type);
      // Non-optional attributes cannot be deleted.
    }
  }
}

/// Template for the default auto-generated builder.
///   {0} is a comma-separated list of builder arguments, including the trailing
///       `loc` and `ip`;
///   {1} is the code populating `operands`, `results` and `attributes`,
///       `successors` fields;
///   {2} is the argument list passed to `super().__init__`.
/// `{{` is the formatv escape for a literal `{`, so the `attributes` line
/// reads `attributes = {}` in the output.
constexpr const char *initTemplate = R"Py(
  def __init__(self, {0}):
    operands = []
    attributes = {{}
    regions = None
    {1}
    super().__init__({2})
)Py";

/// Template for appending a single element to the operand/result list.
///   {0} is the field name.
constexpr const char *singleOperandAppendTemplate = "operands.append({0})";
constexpr const char *singleResultAppendTemplate = "results.append({0})";

/// Template for appending an optional element to the operand/result list.
///   {0} is the field name.
/// The plain variants skip the element when it is None. The "AttrSized"
/// operand variant, used for ops with attribute-sized operand segments,
/// appends the element unconditionally (so a None is appended as is).
constexpr const char *optionalAppendOperandTemplate =
    "if {0} is not None: operands.append({0})";
constexpr const char *optionalAppendAttrSizedOperandsTemplate =
    "operands.append({0})";
constexpr const char *optionalAppendResultTemplate =
    "if {0} is not None: results.append({0})";

/// Template for appending a list of elements to the operand/result list.
///   {0} is the field name.
/// The "Pack" operand variant, used for ops with attribute-sized operand
/// segments, appends the whole list as a single entry instead of extending
/// `operands` with its elements.
constexpr const char *multiOperandAppendTemplate =
    "operands.extend(_get_op_results_or_values({0}))";
constexpr const char *multiOperandAppendPackTemplate =
    "operands.append(_get_op_results_or_values({0}))";
constexpr const char *multiResultAppendTemplate = "results.extend({0})";

/// Template for attribute builder from raw input in the operation builder.
///   {0} is the builder argument name;
///   {1} is the name of the attribute (the key in `attributes`);
///   {2} is the name under which the attribute builder from raw is looked up
///       in `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initAttributeWithBuilderTemplate =
    R"Py(attributes["{1}"] = ({0} if (
    isinstance({0}, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('{2}')) else
      _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for attribute builder from raw input for optional attribute in the
/// operation builder; nothing is set when the argument is None.
///   {0} is the builder argument name;
///   {1} is the name of the attribute (the key in `attributes`);
///   {2} is the name under which the attribute builder from raw is looked up
///       in `_ods_ir.AttrBuilder`.
/// Use the value the user passed in if either it is already an Attribute or
/// there is no method registered to make it an Attribute.
constexpr const char *initOptionalAttributeWithBuilderTemplate =
    R"Py(if {0} is not None: attributes["{1}"] = ({0} if (
        isinstance({0}, _ods_ir.Attribute) or
        not _ods_ir.AttrBuilder.contains('{2}')) else
          _ods_ir.AttrBuilder.get('{2}')({0}, context=_ods_context)))Py";

/// Template for setting a unit attribute in the operation builder; the
/// attribute is only added when the argument is truthy.
///   {0} is the name of the attribute (the key in `attributes`);
///   {1} is the builder argument name.
/// Note that the placeholder order is the reverse of the two templates above.
constexpr const char *initUnitAttributeTemplate =
    R"Py(if bool({1}): attributes["{0}"] = _ods_ir.UnitAttr.get(
      _ods_get_default_loc_context(loc)))Py";

/// Template to initialize the successors list in the builder if there are any
/// successors.
///   {0} is the value to initialize the successors list to.
constexpr const char *initSuccessorsTemplate = R"Py(_ods_successors = {0})Py";

/// Template to append or extend the list of successors in the builder.
///   {0} is the list method ('append' or 'extend');
///   {1} is the value to add.
constexpr const char *addSuccessorTemplate = R"Py(_ods_successors.{0}({1}))Py";

/// Returns true if the SameOperandsAndResultType trait can be used to infer
/// result types of the given operation, i.e. the op has the trait and no
/// variable-length results.
static bool hasSameArgumentAndResultTypes(const Operator &op) {
  if (!op.getTrait("::mlir::OpTrait::SameOperandsAndResultType"))
    return false;
  return op.getNumVariableLengthResults() == 0;
}

/// Returns true if the FirstAttrDerivedResultType trait can be used to infer
/// result types of the given operation, i.e. the op has the trait and no
/// variable-length results.
static bool hasFirstAttrDerivedResultTypes(const Operator &op) {
  if (!op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType"))
    return false;
  return op.getNumVariableLengthResults() == 0;
}

/// Returns true if the InferTypeOpInterface can be used to infer result types
/// of the given operation, i.e. the op implements the interface and has no
/// regions.
static bool hasInferTypeInterface(const Operator &op) {
  if (!op.getTrait("::mlir::InferTypeOpInterface::Trait"))
    return false;
  return op.getNumRegions() == 0;
}

/// Returns true if any of the supported mechanisms (a result-type trait or the
/// type inference interface) can be used to infer result types of the given
/// operation.
static bool canInferType(const Operator &op) {
  if (hasSameArgumentAndResultTypes(op))
    return true;
  if (hasFirstAttrDerivedResultTypes(op))
    return true;
  return hasInferTypeInterface(op);
}

/// Appends the sanitized result names of `op` to `builderArgs`, unless the
/// result types can be inferred, in which case the builder takes no result
/// arguments and nothing is appended.
static void
populateBuilderArgsResults(const Operator &op,
                           SmallVectorImpl<std::string> &builderArgs) {
  if (canInferType(op))
    return;

  const int numResults = op.getNumResults();
  for (int idx = 0; idx < numResults; ++idx) {
    std::string argName = op.getResultName(idx).str();
    // Unnamed results get a default name. A sole result is called 'result'
    // to properly match the built-in result accessor; otherwise the name is
    // generated from the position.
    if (argName.empty())
      argName = numResults == 1 ? std::string("result")
                                : formatv("_gen_res_{0}", idx).str();
    builderArgs.push_back(sanitizeName(argName));
  }
}

/// Populates `builderArgs` with the Python-compatible names of builder function
/// arguments using intermixed attributes and operands in the same order as they
/// appear in the `arguments` field of the op definition. Additionally,
/// `operandNames` is populated with names of operands in their order of
/// appearance.
static void populateBuilderArgs(const Operator &op,
                                SmallVectorImpl<std::string> &builderArgs,
                                SmallVectorImpl<std::string> &operandNames) {
  for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
    std::string name = op.getArgName(i).str();
    if (name.empty())
      name = formatv("_gen_arg_{0}", i);
    name = sanitizeName(name);
    builderArgs.push_back(name);
    if (!isa<NamedAttribute *>(op.getArg(i)))
      operandNames.push_back(name);
  }
}

/// Populates `builderArgs` with the Python-compatible names of builder function
/// successor arguments. Additionally, `successorArgNames` is also populated.
static void
populateBuilderArgsSuccessors(const Operator &op,
                              SmallVectorImpl<std::string> &builderArgs,
                              SmallVectorImpl<std::string> &successorArgNames) {

  for (int i = 0, e = op.getNumSuccessors(); i < e; ++i) {
    NamedSuccessor successor = op.getSuccessor(i);
    std::string name = std::string(successor.name);
    if (name.empty())
      name = formatv("_gen_successor_{0}", i);
    name = sanitizeName(name);
    builderArgs.push_back(name);
    successorArgNames.push_back(name);
  }
}

/// Populates `builderLines` with additional lines that are required in the
/// builder to set up operation attributes. `argNames` is expected to contain
/// the names of builder arguments that correspond to op arguments, i.e. to the
/// operands and attributes in the same order as they appear in the `arguments`
/// field.
static void
populateBuilderLinesAttr(const Operator &op, ArrayRef<std::string> argNames,
                         SmallVectorImpl<std::string> &builderLines) {
  builderLines.push_back("_ods_context = _ods_get_default_loc_context(loc)");
  for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
    Argument arg = op.getArg(i);
    auto *attribute = llvm::dyn_cast_if_present<NamedAttribute *>(arg);
    if (!attribute)
      continue;

    // Unit attributes are handled specially.
    if (attribute->attr.getStorageType().trim() == "::mlir::UnitAttr") {
      builderLines.push_back(
          formatv(initUnitAttributeTemplate, attribute->name, argNames[i]));
      continue;
    }

    builderLines.push_back(formatv(
        attribute->attr.isOptional() || attribute->attr.hasDefaultValue()
            ? initOptionalAttributeWithBuilderTemplate
            : initAttributeWithBuilderTemplate,
        argNames[i], attribute->name, attribute->attr.getAttrDefName()));
  }
}

/// Populates `builderLines` with additional lines that are required in the
/// builder to set up successors. successorArgNames is expected to correspond
/// to the Python argument name for each successor on the op.
static void
populateBuilderLinesSuccessors(const Operator &op,
                               ArrayRef<std::string> successorArgNames,
                               SmallVectorImpl<std::string> &builderLines) {
  if (successorArgNames.empty()) {
    builderLines.push_back(formatv(initSuccessorsTemplate, "None"));
    return;
  }

  builderLines.push_back(formatv(initSuccessorsTemplate, "[]"));
  for (int i = 0, e = successorArgNames.size(); i < e; ++i) {
    auto &argName = successorArgNames[i];
    const NamedSuccessor &successor = op.getSuccessor(i);
    builderLines.push_back(formatv(addSuccessorTemplate,
                                   successor.isVariadic() ? "extend" : "append",
                                   argName));
  }
}

/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op operands.
static void
populateBuilderLinesOperand(const Operator &op, ArrayRef<std::string> names,
                            SmallVectorImpl<std::string> &builderLines) {
  bool sizedSegments = op.getTrait(attrSizedTraitForKind("operand")) != nullptr;

  // For each element, find or generate a name.
  for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
    const NamedTypeConstraint &element = op.getOperand(i);
    std::string name = names[i];

    // Choose the formatting string based on the element kind.
    StringRef formatString;
    if (!element.isVariableLength()) {
      formatString = singleOperandAppendTemplate;
    } else if (element.isOptional()) {
      if (sizedSegments) {
        formatString = optionalAppendAttrSizedOperandsTemplate;
      } else {
        formatString = optionalAppendOperandTemplate;
      }
    } else {
      assert(element.isVariadic() && "unhandled element group type");
      // If emitting with sizedSegments, then we add the actual list-typed
      // element. Otherwise, we extend the actual operands.
      if (sizedSegments) {
        formatString = multiOperandAppendPackTemplate;
      } else {
        formatString = multiOperandAppendTemplate;
      }
    }

    builderLines.push_back(formatv(formatString.data(), name));
  }
}

/// Python code template that computes the result types for ops with the
/// FirstAttrDerivedResultType trait. The generated code only runs when
/// `results` is None: it looks up the source attribute in `attributes`, takes
/// the wrapped type if that attribute is a TypeAttr and the attribute's own
/// `.type` otherwise, and repeats that type once per result.
///   - {0} is the name of the attribute from which to derive the types.
///   - {1} is the number of results.
constexpr const char *firstAttrDerivedResultTypeTemplate =
    R"Py(if results is None:
  _ods_result_type_source_attr = attributes["{0}"]
  _ods_derived_result_type = (
    _ods_ir.TypeAttr(_ods_result_type_source_attr).value
    if _ods_ir.TypeAttr.isinstance(_ods_result_type_source_attr) else
    _ods_result_type_source_attr.type)
  results = [_ods_derived_result_type] * {1})Py";

/// Python code template that computes the result types for ops with the
/// SameOperandsAndResultType trait. When `results` is None, the generated code
/// reuses the type of the first entry of `operands` for every result.
///   - {0} is the number of results.
constexpr const char *sameOperandsAndResultTypeTemplate =
    R"Py(if results is None: results = [operands[0].type] * {0})Py";

/// Splits `string` on '\n' and appends each piece to `builderLines` as its own
/// entry. At least one entry is always appended, so an empty input yields a
/// single empty line; a trailing newline does not add an extra empty entry.
static void appendLineByLine(StringRef string,
                             SmallVectorImpl<std::string> &builderLines) {
  StringRef remaining = string;
  do {
    auto [line, rest] = remaining.split('\n');
    builderLines.push_back(line.str());
    remaining = rest;
  } while (!remaining.empty());
}

/// Populates `builderLines` with additional lines that are required in the
/// builder to set up op results.
static void
populateBuilderLinesResult(const Operator &op, ArrayRef<std::string> names,
                           SmallVectorImpl<std::string> &builderLines) {
  if (hasSameArgumentAndResultTypes(op)) {
    appendLineByLine(
        formatv(sameOperandsAndResultTypeTemplate, op.getNumResults()).str(),
        builderLines);
    return;
  }

  if (hasFirstAttrDerivedResultTypes(op)) {
    const NamedAttribute &firstAttr = op.getAttribute(0);
    assert(!firstAttr.name.empty() && "unexpected empty name for the attribute "
                                      "from which the type is derived");
    appendLineByLine(formatv(firstAttrDerivedResultTypeTemplate, firstAttr.name,
                             op.getNumResults())
                         .str(),
                     builderLines);
    return;
  }

  if (hasInferTypeInterface(op))
    return;

  bool sizedSegments = op.getTrait(attrSizedTraitForKind("result")) != nullptr;
  builderLines.push_back("results = []");

  // For each element, find or generate a name.
  for (int i = 0, e = op.getNumResults(); i < e; ++i) {
    const NamedTypeConstraint &element = op.getResult(i);
    std::string name = names[i];

    // Choose the formatting string based on the element kind.
    StringRef formatString;
    if (!element.isVariableLength()) {
      formatString = singleResultAppendTemplate;
    } else if (element.isOptional()) {
      formatString = optionalAppendResultTemplate;
    } else {
      assert(element.isVariadic() && "unhandled element group type");
      // If emitting with sizedSegments, then we add the actual list-typed
      // element. Otherwise, we extend the actual operands.
      if (sizedSegments) {
        formatString = singleResultAppendTemplate;
      } else {
        formatString = multiResultAppendTemplate;
      }
    }

    builderLines.push_back(formatv(formatString.data(), name));
  }
}

/// If the operation has variadic regions, adds a builder argument to specify
/// the number of those regions and a builder line to forward it to the generic
/// constructor. Does nothing for ops without variadic regions.
///
/// The argument is named `num_<region>`, where `<region>` is the name of the
/// variadic region with its first character lowercased. The emitted line sets
/// `regions` to the number of non-variadic regions plus that argument.
static void populateBuilderRegions(const Operator &op,
                                   SmallVectorImpl<std::string> &builderArgs,
                                   SmallVectorImpl<std::string> &builderLines) {
  if (op.hasNoVariadicRegions())
    return;

  // The variadic region is the last one, so every region before it is fixed.
  const unsigned numFixedRegions = op.getNumRegions() - 1;
  const NamedRegion &region = op.getRegion(numFixedRegions);

  // This is currently enforced when Operator is constructed.
  assert(op.getNumVariadicRegions() == 1 && region.isVariadic() &&
         "expected the last region to be variadic");

  std::string name =
      ("num_" + region.name.take_front().lower() + region.name.drop_front())
          .str();
  builderArgs.push_back(name);
  builderLines.push_back(
      formatv("regions = {0} + {1}", numFixedRegions, name));
}

/// Emits the default `__init__` builder of the op class to `os`. The builder
/// takes the result arguments first, followed by the operand/attribute
/// arguments, the successors and, if present, the variadic region count.
/// Returns the complete list of Python function arguments that was emitted
/// (including the bare `*` separator and default values) so that downstream
/// users do not have to rebuild it.
static SmallVector<std::string> emitDefaultOpBuilder(const Operator &op,
                                                     raw_ostream &os) {
  SmallVector<std::string> builderArgs;
  SmallVector<std::string> builderLines;
  SmallVector<std::string> operandArgNames;
  SmallVector<std::string> successorArgNames;
  builderArgs.reserve(op.getNumOperands() + op.getNumResults() +
                      op.getNumNativeAttributes() + op.getNumSuccessors());

  // Collect argument names, remembering where each group ends. Layout of
  // builderArgs once everything is populated:
  //   [ result_args  operand_attr_args  successor_args  regions ]
  populateBuilderArgsResults(op, builderArgs);
  const size_t resultArgsEnd = builderArgs.size();
  populateBuilderArgs(op, builderArgs, operandArgNames);
  const size_t operandAttrArgsEnd = builderArgs.size();
  populateBuilderArgsSuccessors(op, builderArgs, successorArgNames);

  // Emit the body lines. The view into builderArgs is only used before
  // populateBuilderRegions may grow the vector.
  {
    ArrayRef<std::string> argsView(builderArgs);
    populateBuilderLinesOperand(op, operandArgNames, builderLines);
    populateBuilderLinesAttr(op, argsView.drop_front(resultArgsEnd),
                             builderLines);
    populateBuilderLinesResult(op, argsView.take_front(resultArgsEnd),
                               builderLines);
  }
  populateBuilderLinesSuccessors(op, successorArgNames, builderLines);
  populateBuilderRegions(op, builderArgs, builderLines);

  // Returns true if the builderArgs entry at `argPos` must be a Python keyword
  // argument. Result, successor and region arguments are always positional.
  // Among operand/attribute arguments, the keyword ones are:
  // - optional named attributes (including unit attributes)
  // - default-valued named attributes
  // - optional operands
  auto isKeywordArg = [&](size_t argPos) -> bool {
    const bool isOperandOrAttr =
        argPos >= resultArgsEnd && argPos < operandAttrArgsEnd;
    if (!isOperandOrAttr)
      return false;
    Argument arg = op.getArg(argPos - resultArgsEnd);
    if (auto *namedAttr = llvm::dyn_cast_if_present<NamedAttribute *>(arg))
      return namedAttr->attr.isOptional() || namedAttr->attr.hasDefaultValue();
    if (auto *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg))
      return operand->isOptional();
    return false;
  };

  // StringRefs in functionArgs refer to strings owned by builderArgs, which
  // is not resized from here on.
  SmallVector<StringRef> functionArgs;
  SmallVector<size_t> keywordArgPositions;

  // Positional arguments come first; keyword ones are recorded for later.
  for (size_t pos = 0, end = builderArgs.size(); pos != end; ++pos) {
    if (isKeywordArg(pos))
      keywordArgPositions.push_back(pos);
    else
      functionArgs.push_back(builderArgs[pos]);
  }

  // A bare '*' makes all following arguments keyword-only.
  functionArgs.push_back("*");

  // Keyword arguments get a default of `None`; the default is appended before
  // the reference is taken so that it is part of the referenced string.
  for (size_t pos : keywordArgPositions) {
    std::string &keywordArg = builderArgs[pos];
    keywordArg += "=None";
    functionArgs.push_back(keywordArg);
  }
  if (canInferType(op))
    functionArgs.push_back("results=None");
  functionArgs.push_back("loc=None");
  functionArgs.push_back("ip=None");

  // Arguments forwarded to the base class initializer.
  SmallVector<std::string> initArgs{"self.OPERATION_NAME",
                                    "self._ODS_REGIONS",
                                    "self._ODS_OPERAND_SEGMENTS",
                                    "self._ODS_RESULT_SEGMENTS",
                                    "attributes=attributes",
                                    "results=results",
                                    "operands=operands",
                                    "successors=_ods_successors",
                                    "regions=regions",
                                    "loc=loc",
                                    "ip=ip"};

  os << formatv(initTemplate, llvm::join(functionArgs, ", "),
                llvm::join(builderLines, "\n    "), llvm::join(initArgs, ", "));

  // Hand back owning copies, since builderArgs dies with this function.
  SmallVector<std::string> ownedFunctionArgs;
  ownedFunctionArgs.reserve(functionArgs.size());
  for (StringRef arg : functionArgs)
    ownedFunctionArgs.push_back(arg.str());
  return ownedFunctionArgs;
}

static void emitSegmentSpec(
    const Operator &op, const char *kind,
    llvm::function_ref<int(const Operator &)> getNumElements,
    llvm::function_ref<const NamedTypeConstraint &(const Operator &, int)>
        getElement,
    raw_ostream &os) {
  std::string segmentSpec("[");
  for (int i = 0, e = getNumElements(op); i < e; ++i) {
    const NamedTypeConstraint &element = getElement(op, i);
    if (element.isOptional()) {
      segmentSpec.append("0,");
    } else if (element.isVariadic()) {
      segmentSpec.append("-1,");
    } else {
      segmentSpec.append("1,");
    }
  }
  segmentSpec.append("]");

  os << formatv(opClassSizedSegmentsTemplate, kind, segmentSpec);
}

/// Emits the `_ODS_REGIONS` class attribute of the op, i.e. the pair
/// (min_region_count, has_no_variadic_regions).
static void emitRegionAttributes(const Operator &op, raw_ostream &os) {
  // Note that the base OpView class defines this as (0, True).
  const unsigned numFixedRegions =
      op.getNumRegions() - op.getNumVariadicRegions();
  const char *hasNoVariadicRegions =
      op.hasNoVariadicRegions() ? "True" : "False";
  os << formatv(opClassRegionSpecTemplate, numFixedRegions,
                hasNoVariadicRegions);
}

/// Emits one accessor per named region of `op`; unnamed regions are skipped.
/// A variadic region is accessed through the slice `<index>:` and typed as a
/// region sequence, any other region through its plain index and typed as a
/// single region.
static void emitRegionAccessors(const Operator &op, raw_ostream &os) {
  const unsigned numRegions = op.getNumRegions();
  for (unsigned idx = 0; idx != numRegions; ++idx) {
    const NamedRegion &region = op.getRegion(idx);
    if (region.name.empty())
      continue;

    const bool isVariadic = region.isVariadic();
    assert((!isVariadic || idx == numRegions - 1) &&
           "expected only the last region to be variadic");

    std::string indexExpr = std::to_string(idx);
    if (isVariadic)
      indexExpr += ":";
    const char *regionType =
        isVariadic ? "_ods_ir.RegionSequence" : "_ods_ir.Region";
    os << formatv(regionAccessorTemplate, sanitizeName(region.name), indexExpr,
                  regionType);
  }
}

/// Emits builder that extracts results from op
static void emitValueBuilder(const Operator &op,
                             SmallVector<std::string> functionArgs,
                             raw_ostream &os) {
  // Params with (possibly) default args.
  auto valueBuilderParams =
      llvm::map_range(functionArgs, [](const std::string &argAndMaybeDefault) {
        SmallVector<StringRef> argMaybeDefault =
            llvm::to_vector<2>(llvm::split(argAndMaybeDefault, "="));
        auto arg = llvm::convertToSnakeFromCamelCase(argMaybeDefault[0]);
        if (argMaybeDefault.size() == 2)
          return arg + "=" + argMaybeDefault[1].str();
        return arg;
      });
  // Actual args passed to op builder (e.g., opParam=op_param).
  auto opBuilderArgs = llvm::map_range(
      llvm::make_filter_range(functionArgs,
                              [](const std::string &s) { return s != "*"; }),
      [](const std::string &arg) {
        auto lhs = *llvm::split(arg, "=").begin();
        return (lhs + "=" + llvm::convertToSnakeFromCamelCase(lhs)).str();
      });
  std::string nameWithoutDialect = sanitizeName(
      op.getOperationName().substr(op.getOperationName().find('.') + 1));
  if (nameWithoutDialect == op.getCppClassName())
    nameWithoutDialect += "_";
  std::string params = llvm::join(valueBuilderParams, ", ");
  std::string args = llvm::join(opBuilderArgs, ", ");
  if (op.getNumVariableLengthResults()) {
    os << formatv(valueBuilderVariadicTemplate, nameWithoutDialect,
                  op.getCppClassName(), params, args);
  } else {
    std::string type = op.getCppClassName().str();
    const char *results = "";
    if (op.getNumResults() > 1) {
      type = "_ods_ir.OpResultList";
      results = ".results";
    } else if (op.getNumResults() == 1) {
      type = "_ods_ir.OpResult";
      results = ".result";
    }
    os << formatv(valueBuilderTemplate, nameWithoutDialect,
                  op.getCppClassName(), params, args, type, results);
  }
}

/// Retrieve the description of the given op and generate a docstring for it.
static std::string makeDocStringForOp(const Operator &op) {
  if (!op.hasDescription())
    return "";

  auto desc = op.getDescription().rtrim(" \t").str();
  // Replace all """ with \"\"\" to avoid early termination of the literal.
  desc = std::regex_replace(desc, std::regex(R"(""")"), R"(\"\"\")");

  std::string docString = "\n";
  llvm::raw_string_ostream os(docString);
  raw_indented_ostream identedOs(os);
  os << R"(  r""")" << "\n";
  identedOs.printReindented(desc, "  ");
  if (!StringRef(desc).ends_with("\n"))
    os << "\n";
  os << R"(  """)" << "\n";

  return docString;
}

/// Emits bindings for a specific Op to the given output stream.
static void emitOpBindings(const Operator &op, raw_ostream &os) {
  os << formatv(opClassTemplate, op.getCppClassName(), op.getOperationName(),
                makeDocStringForOp(op));

  // Sized segments.
  if (op.getTrait(attrSizedTraitForKind("operand")) != nullptr) {
    emitSegmentSpec(op, "OPERAND", getNumOperands, getOperand, os);
  }
  if (op.getTrait(attrSizedTraitForKind("result")) != nullptr) {
    emitSegmentSpec(op, "RESULT", getNumResults, getResult, os);
  }

  emitRegionAttributes(op, os);
  SmallVector<std::string> functionArgs = emitDefaultOpBuilder(op, os);
  emitOperandAccessors(op, os);
  emitAttributeAccessors(op, os);
  emitResultAccessors(op, os);
  emitRegionAccessors(op, os);
  emitValueBuilder(op, functionArgs, os);
}

/// Emits bindings for the dialect specified in the command line, including file
/// headers and utilities. Returns `false` on success to comply with Tablegen
/// registration requirements.
static bool emitAllOps(const RecordKeeper &records, raw_ostream &os) {
  if (clDialectName.empty())
    llvm::PrintFatalError("dialect name not provided");

  os << fileHeader;
  if (!clDialectExtensionName.empty())
    os << formatv(dialectExtensionTemplate, clDialectName.getValue());
  else
    os << formatv(dialectClassTemplate, clDialectName.getValue());

  for (const Record *rec : records.getAllDerivedDefinitions("Op")) {
    Operator op(rec);
    if (op.getDialectName() == clDialectName.getValue())
      emitOpBindings(op, os);
  }
  return false;
}

/// Registers `emitAllOps` as the generator named `gen-python-op-bindings`,
/// with the description "Generate Python bindings for MLIR Ops".
static GenRegistration
    genPythonBindings("gen-python-op-bindings",
                      "Generate Python bindings for MLIR Ops", &emitAllOps);
